// C / POSIX headers
#include <errno.h>
#include <fcntl.h>
#include <math.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/ipc.h>
//REF: https://stackoverflow.com/questions/2279052/increase-stack-size-in-linux-with-setrlimit/2279084#2279084
#include <sys/resource.h>
#include <sys/sem.h>
#include <sys/shm.h>

// C++ standard library
#include <cmath>
#include <iostream>
#include <random>
#include <sstream>
#include <stdexcept>
#include <string>
using namespace std;

/**
 * Wrapper around one System V shared-memory segment.
 *
 * Usage order seen in this file: setKey() -> setMem() -> readMem()/writeMem()
 * -> close(). getShmId() is declared here; its definition is not visible in
 * this part of the file.
 */
class CSHM {
	public:
		// Local copy of the segment contents, filled by readMem().
		char read_data[820000];

		int getShmId();
		void setKey(key_t key);
		void setMem(int permission, int r_size);
		void writeMem(std::string str);
		void readMem();
		void close();

	private:
		int max_size;            // segment size in bytes, as passed to setMem()
		int m_shmid;             // identifier returned by shmget()
		key_t m_key;             // IPC key stored by setKey()
		char *m_shared_memory;   // address returned by shmat()
};
void CSHM::setKey(key_t key) {
    m_key = key;
}
/**
 * Create (or open) the segment for the stored key and attach it.
 *
 * @param permission  mode bits OR-ed with IPC_CREAT for shmget()
 * @param r_size      segment size in bytes; remembered in max_size
 *
 * Any failure is reported with perror() and terminates the process with
 * exit(1). If attaching fails the segment is marked for removal first.
 */
void CSHM::setMem(int permission, int r_size) {
	max_size = r_size;

	m_shmid = shmget(m_key, max_size, IPC_CREAT | permission);
	if (m_shmid < 0) {
		perror("shmget failed ");
		exit(1);
	}

	void *attached = shmat(m_shmid, NULL, 0);
	m_shared_memory = (char*)attached;
	if (m_shared_memory == (char *)-1) {
		perror("shmat failed ");
		shmctl(m_shmid, IPC_RMID, NULL);
		exit(1);
	}
}
void CSHM::writeMem(string str) {
	memcpy(m_shared_memory, str.c_str() , str.size());
}
/**
 * Snapshot the segment into read_data.
 *
 * At most sizeof(read_data) - 1 bytes are copied and the buffer is always
 * NUL-terminated: callers in this file build a std::string from read_data,
 * which requires a terminator, and max_size is not otherwise guaranteed to
 * fit into read_data.
 */
void CSHM::readMem() {
	const size_t capacity = sizeof(read_data) - 1;  // keep one byte for the NUL
	size_t n = max_size > 0 ? static_cast<size_t>(max_size) : 0;
	if (n > capacity) n = capacity;

	const bool attached = m_shared_memory != NULL && m_shared_memory != (char *)-1;
	if (!attached) n = 0;
	if (n > 0) memcpy(read_data, m_shared_memory, n);
	read_data[n] = '\0';
}
void CSHM::close() {
	sleep(3); 
	void* shmdt(void *m_shmid);
	shmctl(m_shmid , IPC_RMID, NULL); 
}

// Initial value given to each semaphore by sem_init().
#define SEM_RESOURCE_MAX 1
// Brace initializers for `struct sembuf`, all acting on semaphore index 0
// with an operation of -1 (lock) or +1 (unlock).
// The *_1 variants use SEM_UNDO; they are not referenced in this part of the file.
#define SEM_LOCK_1 {0, -1, SEM_UNDO} 
#define SEM_UNLOCK_1 {0, 1, SEM_UNDO}
// The *_2 variants use IPC_NOWAIT; sem_wait()/sem_post() retry them in a loop.
#define SEM_LOCK_2 {0, -1, IPC_NOWAIT} 
#define SEM_UNLOCK_2 {0, 1, IPC_NOWAIT}
// Argument type for semctl(); declared here by the program itself.
union semun {
    int val;             
    struct semid_ds *buf;     
    unsigned short int *array; 
	struct seminfo  *__buf;
};

/**
 * Create a new single-semaphore set for `key` and set its value to
 * SEM_RESOURCE_MAX.
 *
 * The set is created with IPC_CREAT|IPC_EXCL, so an already existing set for
 * the same key is an error. Any failure is reported with perror() and ends
 * the process with exit(1); if the initial value cannot be set, the freshly
 * created set is removed first so it is not left behind.
 *
 * @param semid  receives the semaphore set identifier
 * @param key    IPC key of the set
 * @return always 0 (failures do not return)
 */
static inline int
sem_init(int *semid, key_t key) {
    if((*semid = semget(key, 1, IPC_CREAT|IPC_EXCL|0606)) == -1) {
        perror("semget failed:");
		exit(1);
    }
    union semun semopts;
    semopts.val = SEM_RESOURCE_MAX;
    if (semctl(*semid, 0, SETVAL, semopts) == -1) {
        perror("semctl SETVAL failed:");
        semctl(*semid, 0, IPC_RMID, 0);
        exit(1);
    }
    return 0;
}

/**
 * Decrement the semaphore, polling until the operation succeeds.
 *
 * The operation is non-blocking (SEM_LOCK_2 uses IPC_NOWAIT), so EAGAIN
 * means "not available yet" and EINTR means "interrupted"; both are retried.
 * Any other error (for example the set was removed) can never succeed on
 * retry, so it is reported and the process exits instead of spinning forever.
 *
 * @param semid  pointer to the semaphore set identifier
 */
static inline void
sem_wait(int *semid) {
	struct sembuf sem_lock = SEM_LOCK_2;
	while (semop(*semid, &sem_lock, 1) == -1) {
		if (errno == EAGAIN || errno == EINTR) continue;
		perror("semop (wait) failed:");
		exit(1);
	}
}

/**
 * Increment the semaphore, retrying while the failure is transient.
 *
 * EAGAIN and EINTR are retried; any other semop() error is permanent, so it
 * is reported and the process exits instead of looping forever.
 *
 * @param semid  pointer to the semaphore set identifier
 */
static inline void
sem_post(int *semid) {
	struct sembuf sem_lock = SEM_UNLOCK_2;
	while (semop(*semid, &sem_lock, 1) == -1) {
		if (errno == EAGAIN || errno == EINTR) continue;
		perror("semop (post) failed:");
		exit(1);
	}
}

// Remove the semaphore set identified by *semid. The result of semctl() is
// not inspected.
static inline void
sem_destroy(int *semid) {
    const int set_id = *semid;
    semctl(set_id, 0, IPC_RMID, 0);
}

// global variable section 1
// IPC handles and compile-time configuration.
// Segment 1 is the one read by pof_output_reading(); segment 2 is the one
// written by pof_update_writing().
CSHM pof_shared_memory_1;
CSHM pof_shared_memory_2;
// sem_1 is waited on before segment 1 is read; sem_2 is posted after
// segment 2 has been written (see main()).
int sem_1;
int sem_2;
// Keys are derived from the command-line number in main().
key_t key_1;
key_t key_2;
// Mode bits passed to setMem() for segment 1 and segment 2 respectively.
int permission_1 = 0602;
int permission_2 = 0604;
// Array extents used throughout the file.
const int n_NN = 1;
const int n_sample_of_action = 252;
const int n_rank = 2;
const double lambda = -1;
// xi grows from xi_start towards xi_end in steps of sqrt_xi_interval on the
// square-root scale (see initialize_variables()).
const double xi_start = 0;
const double xi_end = 10.0 / 500; 
const int total_epoch = 1000; 
const double sqrt_xi_interval = (sqrt(xi_end) - sqrt(xi_start))/(total_epoch*1000);
// When > 0, main() starts xi from the value implied by this checkpoint.
const int checkpoint = 0;
// Starting point / running reference for the minimum loss in approxloss().
double default_min_loss = 30.0;
const int n_batch = 1;
// Number of test-case entries per rank in each batch.
const int const_one_label_intest_size_per_sample = 5000;
// Used to build the log directory name and the IPC keys in main().
const int pof_section = 3;

// global variable section 2
// Per-message state: filled by pof_output_reading(), reset by
// initialize_variables(). The trailing dimension of 5 is the upper bound for
// n_class that these arrays can hold.
int mode = -10000;
int default_action_batch[n_batch] = {};
double mu_batch[n_batch][2] = {};
double sigma_batch[n_batch][2] = {};
double logp_batch[n_batch][n_sample_of_action] = {};
int n_class = -10000;
double ccon_w_batch[n_NN][n_batch][n_sample_of_action][5] = {};
int n_test_case[n_NN][2];
int ans_batch[n_NN][n_batch][const_one_label_intest_size_per_sample*n_rank] = {};
double test_case_ccon_w_batch[n_NN][n_batch][const_one_label_intest_size_per_sample*n_rank][5] = {};
// Results sent back by pof_update_writing().
int act_batch[n_batch] = {};
double grad_ccon_batch[n_NN][n_batch][n_sample_of_action][5] = {};
double grad_test_case_ccon_batch[n_NN][n_batch][const_one_label_intest_size_per_sample*n_rank][5] = {};
// Current margin used by calc_normaltable() and the test-case gradient.
double xi = xi_start;
// Constants of the loss term beta * log(max(prob / th, 1)) in approxloss().
double th = 0.0001;
double beta = 3.;

// global variable section 3-1
// Per-batch working copies: reset by initialize_variables_for_batch() and
// filled from the *_batch arrays in main().
int default_action = -10000;
double mu[2] = {}, sigma[2]={};
double ccon_w[n_NN][n_sample_of_action][5] = {};
int ans[n_NN][const_one_label_intest_size_per_sample*n_rank]= {};
// Number of test cases per label value, counted in main().
int ans_sum[n_NN][2] = {};
double test_case_ccon_w[n_NN][const_one_label_intest_size_per_sample*n_rank][5] = {};
// Action chosen by approxloss() when it is called with is_main != 0.
int act = -10000;
// Loss of the unperturbed state, used as the reference in the gradients.
double main_approxloss;
double grad_ccon[n_NN][n_sample_of_action][5] = {};
double grad_test_case_ccon[n_NN][const_one_label_intest_size_per_sample*n_rank][5] = {};

void signal_callback_handler(int signum) {
	cout << "Caught signal " << signum << endl;
	sem_destroy(&sem_1);
	sem_destroy(&sem_2);
	cout << "DESTROY S1 & S2" << endl;
	pof_shared_memory_1.close();
	pof_shared_memory_2.close();
	cout << "DESTROY SHM1 & SHM2" << endl;
	exit(signum);
}

/**
 * Parse the newline-separated text message held in
 * pof_shared_memory_1.read_data into the "section 2" globals.
 *
 * Line order, as consumed below:
 *   mode, arrival size (parsed, not used further),
 *   n_batch default actions,
 *   per batch: mu[0], mu[1], sigma[0], sigma[1],
 *   n_class,
 *   ccon_w_batch values (clamped to [-100, 100]),
 *   n_test_case per network and rank,
 *   per network / rank / batch / test case: the label followed by n_class
 *   values (clamped to [-100, 100]).
 *
 * Compared with a plain getline/stoi sequence this version
 *  - fails loudly when the message ends early instead of silently re-parsing
 *    the previous line,
 *  - rejects an n_class larger than the class dimension of the global arrays,
 *  - rejects labels outside [0, n_rank), because labels are later used as
 *    array indices.
 *
 * @throws std::runtime_error   if the message ends early
 * @throws std::out_of_range    if n_class or a label cannot be stored safely
 * @throws std::invalid_argument / std::out_of_range from stoi/stod on
 *         malformed numbers
 */
void pof_output_reading() {
	string shared_string(pof_shared_memory_1.read_data);
	stringstream pof_output(shared_string);
	string line;

	// Fetch the next line or fail; getline() leaves `line` untouched once the
	// stream is exhausted, which would otherwise repeat the previous value.
	auto next_line = [&]() -> const string& {
		if (!getline(pof_output, line)) {
			throw std::runtime_error("pof_output_reading: message ended early");
		}
		return line;
	};
	auto next_int = [&]() -> int { return stoi(next_line()); };
	auto next_double = [&]() -> double { return stod(next_line()); };
	auto next_clamped = [&]() -> double { return max(-100., min(stod(next_line()), 100.)); };

	mode = next_int();
	const int arrival_intest_size = next_int();
	(void)arrival_intest_size;  // consumed to keep the line order; not needed here

	for (int i = 0; i < n_batch; i++) {
		default_action_batch[i] = next_int();
	}
	for (int i = 0; i < n_batch; i++) {
		mu_batch[i][0] = next_double();
		mu_batch[i][1] = next_double();
		sigma_batch[i][0] = next_double();
		sigma_batch[i][1] = next_double();
	}

	n_class = next_int();
	// Class dimension of the global arrays (their innermost extent).
	const int max_classes = static_cast<int>(sizeof(ccon_w_batch[0][0][0]) / sizeof(ccon_w_batch[0][0][0][0]));
	if (n_class > max_classes) {
		throw std::out_of_range("pof_output_reading: n_class exceeds array capacity");
	}

	for (int i = 0; i < n_NN; i++) {
		for (int j = 0; j < n_batch; j++) {
			for (int k = 0; k < n_sample_of_action; k++) {
				for (int m = 0; m < n_class; m++) {
					ccon_w_batch[i][j][k][m] = next_clamped();
				}
			}
		}
	}
	for (int i = 0; i < n_NN; i++) {
		for (int j = 0; j < n_rank; j++) {
			n_test_case[i][j] = next_int();
		}
	}
	for (int i = 0; i < n_NN; i++) {
		for (int j = 0; j < n_rank; j++) {
			for (int k = 0; k < n_batch; k++) {
				for (int m = 0; m < const_one_label_intest_size_per_sample; m++) {
					const int slot = j*const_one_label_intest_size_per_sample + m;
					const int label = next_int();
					if (label < 0 || label >= n_rank) {
						throw std::out_of_range("pof_output_reading: label outside [0, n_rank)");
					}
					ans_batch[i][k][slot] = label;
					for (int n = 0; n < n_class; n++) {
						test_case_ccon_w_batch[i][k][slot][n] = next_clamped();
					}
				}
			}
		}
	}
}


/**
 * Serialise the results as one value per line and publish them through
 * pof_shared_memory_2.
 *
 * The chosen action of every batch is always written. The two gradient
 * arrays follow only when mode == 0.
 */
void pof_update_writing() {
	string payload;

	for (int b = 0; b < n_batch; ++b) {
		payload += to_string(act_batch[b]);
		payload += "\n";
	}

	if (mode == 0) {
		const int n_test_rows = const_one_label_intest_size_per_sample*n_rank;

		for (int nn = 0; nn < n_NN; ++nn)
			for (int b = 0; b < n_batch; ++b)
				for (int s = 0; s < n_sample_of_action; ++s)
					for (int c = 0; c < n_class; ++c) {
						payload += to_string(grad_ccon_batch[nn][b][s][c]);
						payload += "\n";
					}

		for (int nn = 0; nn < n_NN; ++nn)
			for (int b = 0; b < n_batch; ++b)
				for (int row = 0; row < n_test_rows; ++row)
					for (int c = 0; c < n_class; ++c) {
						payload += to_string(grad_test_case_ccon_batch[nn][b][row][c]);
						payload += "\n";
					}
	}

	pof_shared_memory_2.writeMem(payload);
}


// Integer counters indexed as results[network][label][class][0 or 1];
// filled by calc_normaltable().
class table {
	public:
		int results[n_NN][2][5][2] = {};

		// Zero every counter.
		void reset() {
			for (int nn = 0; nn < n_NN; ++nn) reset(nn);
		}

		// Zero the counters of a single network.
		void reset(int nn_order) {
			memset(&results[nn_order], 0, sizeof results[nn_order]);
		}
};
// Values produced by calculate_posterior(), indexed as
// post_prob[network][class][rank]; zero-initialised.
struct posterior {
	double post_prob[n_NN][5][2] = {};
};
// Class index chosen for every action sample, per network; filled by
// calc_normalclass().
class classassign {
	public:
		int assign[n_NN][n_sample_of_action];

		// Set every entry back to 0.
		void reset() {
			for (auto &row : assign) {
				for (int &slot : row) slot = 0;
			}
		}
};

// global variable section 3-2
// Per-batch state derived from the inputs; all three are reset in
// initialize_variables_for_batch().
// Counters built by calc_normaltable().
table normaltable;
// Per-sample class choice built by calc_normalclass().
classassign normalclass;
// Per-test-case flags set by calc_normaltable(): [..][class][1] when the
// "+xi" condition held, [..][class][0] when the "-xi" condition held.
int normt_eff[n_NN][const_one_label_intest_size_per_sample*n_rank][5][2] = {};


void calc_normaltable() {
	for (int i = 0; i < n_NN; i++) {
		int n_test_case_pm = const_one_label_intest_size_per_sample*n_rank;
		for (int j = 0; j < n_test_case_pm; j++) {
			double max = -5000000, secmax = -5000000;
			for (int k = 0; k < n_class; k++) {
				if (test_case_ccon_w[i][j][k] > max) {
					secmax = max;
					max = test_case_ccon_w[i][j][k];
				}
				else secmax = secmax > test_case_ccon_w[i][j][k] ? secmax : test_case_ccon_w[i][j][k];
			}
			for (int k = 0; k < n_class; k++) {
				if (test_case_ccon_w[i][j][k] + xi > max) {
					normaltable.results[i][ans[i][j]][k][1]++;
					normt_eff[i][j][k][1] = 1;
				}
				if (test_case_ccon_w[i][j][k] - xi > secmax) {
					normaltable.results[i][ans[i][j]][k][0]++;
					normt_eff[i][j][k][0] = 1;
				}
			}
		}
	}
}


/**
 * Turn the counters of `t` into the post_prob values.
 *
 * For each network and class:
 *   weight(rank)      = n_test_case[rank] / sum of n_test_case over ranks
 *   denominator       = 1e-8 + sum over ranks of
 *                       weight * results[rank][class][0] / (ans_sum[rank] + 1e-8)
 *   post_prob[class][rank] =
 *       (weight * results[rank][class][1] / (ans_sum[rank] + 1e-8) + 1e-8)
 *       / denominator
 *
 * @param t  counters to evaluate (passed by value)
 * @return the filled posterior
 */
posterior calculate_posterior(table t)  {
	const double tiny = 0.00000001;
	posterior result;

	for (int nn = 0; nn < n_NN; ++nn) {
		double total_cases = 0;
		for (int r = 0; r < n_rank; ++r) {
			total_cases += n_test_case[nn][r];
		}

		for (int c = 0; c < n_class; ++c) {
			double denominator = tiny;
			for (int r = 0; r < n_rank; ++r) {
				const double weight = (double)n_test_case[nn][r] / total_cases;
				denominator += weight * t.results[nn][r][c][0] / (ans_sum[nn][r] + tiny);
			}
			for (int r = 0; r < n_rank; ++r) {
				const double weight = (double)n_test_case[nn][r] / total_cases;
				const double scaled = (weight * t.results[nn][r][c][1]) / (ans_sum[nn][r] + tiny);
				result.post_prob[nn][c][r] = (scaled + tiny) / denominator;
			}
		}
	}
	return result;
}


// For every action sample, store in normalclass.assign the index of the
// class with the largest ccon_w value (first one wins on ties). Entries are
// left untouched when no value exceeds -5000000.
void calc_normalclass() {
	for (int nn = 0; nn < n_NN; ++nn) {
		for (int s = 0; s < n_sample_of_action; ++s) {
			const double *scores = ccon_w[nn][s];
			double best = -5000000;
			for (int c = 0; c < n_class; ++c) {
				if (scores[c] > best) {
					best = scores[c];
					normalclass.assign[nn][s] = c;
				}
			}
		}
	}
}


/**
 * Compute the approximate loss for posterior `p` and class assignment `c`.
 *
 * Per-sample loss is beta * log(max(post_prob[0][assigned class][1] / th, 1)).
 * The smallest per-sample loss below default_min_loss becomes min_loss.
 *
 * @param p        posterior values; only network index 0 is read
 * @param c        class assigned to each action sample
 * @param is_main  when non-zero the call also has side effects: it selects
 *                 the global `act`, updates `default_min_loss` and may print
 * @return min_loss + 5 * min(min_post_prob2, 1) - cnt * 0.0015, or
 *         default_min_loss on the early-return path
 */
double approxloss(posterior p, classassign c, int is_main = 0) {
	double min_loss;
	min_loss = default_min_loss;
	int tmp = default_action;
	int index_tmp = 0;
	// Early exit for default_action values below -10000: the action is
	// default_action + 10000 and no loss is evaluated.
	if (default_action<-10000) {
		if (is_main) {
			act = tmp+10000;
			printf("hazard\t\t%d\n",act);
		}
		return default_min_loss;
	}
	// NOTE(review): check_view is written below but never read.
	int check_view[n_sample_of_action] = {};
	// Find the smallest per-sample loss and remember its sample index.
	for (int i = 0; i < n_sample_of_action; i++) {
		double loss_tmp = 0.0;
		loss_tmp = beta * log(max( p.post_prob[index_tmp][c.assign[index_tmp][i]][1]/th, 1.));
		if (loss_tmp < min_loss) {
			min_loss = loss_tmp;
			tmp = i;
		}
	}
	double min_post_prob2 = 1e9;
	int cnt = 0;
	// Count the samples whose loss is within 1e-9 of the minimum.
	// NOTE(review): 252 is hard-coded; it equals n_sample_of_action today.
	for(int i = 0; i < 252; i++) {
		double loss_tmp = beta * log(max(p.post_prob[index_tmp][c.assign[index_tmp][i]][1]/th, 1.));
			if(loss_tmp < min_loss + 1e-9)
				cnt++;
		}
	if (is_main) {
		// Choose an action on a 376 x 376 grid. Each cell (i, j) refers to the
		// sample index 126*(i>=188) + clamp(j-125, 0, 125). Cells whose sample
		// has (near-)minimal loss get a weight from the exponent
		//   -(i*0.016-3-mu[0])^2/sigma[0]^2 - (j*0.016-3-mu[1])^2/sigma[1]^2.
		// NOTE(review): the array below is about 1.28 MB of stack, and a zero
		// sigma divides by zero -- confirm both are acceptable.
		double sum_unnormalized_prob = 0;
		double unnormalized_prob[400][400] = {};
		double max_tmp = -1e9;
		
		// Pass 1: largest exponent among eligible cells; it is subtracted in
		// pass 2 before exp() is applied.
		for(int i = 0; i < 376; i++) {
			for(int j = 0; j < 376; j++) {
				int ref_sample = 126 * (i >= 188) + (j > 250 ? 125 : (j > 125 ? j-125 : 0));
				double loss_tmp = beta * log(max(p.post_prob[index_tmp][c.assign[index_tmp][ref_sample]][1]/th, 1.));
				if(loss_tmp < min_loss + 1e-9)
					max_tmp = max_tmp > -(i*0.016-3-mu[0])*(i*0.016-3-mu[0])/sigma[0]/sigma[0]-(j*0.016-3-mu[1])*(j*0.016-3-mu[1])/sigma[1]/sigma[1]?max_tmp:-(i*0.016-3-mu[0])*(i*0.016-3-mu[0])/sigma[0]/sigma[0]-(j*0.016-3-mu[1])*(j*0.016-3-mu[1])/sigma[1]/sigma[1];
			}
		}
		// Pass 2: weights of eligible cells and their sum; other cells stay 0.
		for(int i=0;i<376;i++) {
			for(int j=0;j<376;j++) {
				int ref_sample=126*(i>=188)+(j>250?125:(j>125?j-125:0));
				double loss_tmp = beta * log(max( p.post_prob[index_tmp][c.assign[index_tmp][ref_sample]][1]/th, 1.));
				check_view[ref_sample]=1;
				if (loss_tmp<min_loss+1e-9) {
					unnormalized_prob[i][j]=exp(-(i*0.016-3-mu[0])*(i*0.016-3-mu[0])/sigma[0]/sigma[0]-(j*0.016-3-mu[1])*(j*0.016-3-mu[1])/sigma[1]/sigma[1]-max_tmp);
					sum_unnormalized_prob+=unnormalized_prob[i][j];
				}
			}
		}
		// Draw one cell with probability proportional to its weight; the
		// chosen action is encoded as i*376 + j.
		if (cnt) {
			std::random_device rd;
			std::uniform_real_distribution<double> distr(0, sum_unnormalized_prob);
			double sample=distr(rd);
			for(int i=0;i<376;i++) {
				for(int j=0;j<376;j++) {
					sample-=unnormalized_prob[i][j];
					if(sample<0) {
						tmp=i*376+j;
						goto A;
					}
				}
			}
		}
	}

	// Smallest post_prob[0][class][0] over all classes.
	A:for(int i=0;i<n_class;i++)
		min_post_prob2=min_post_prob2<p.post_prob[index_tmp][i][0]?min_post_prob2:p.post_prob[index_tmp][i][0];
	if (is_main)
	{
		act = tmp;
		// Slow running update of the reference loss used by later calls.
		default_min_loss=default_min_loss*0.99999+(min_loss+30.0/(min_loss+1))*0.00001;
	}
	// Diagnostic line, printed only for the instance whose key matches.
	if (is_main&&key_1==61011+pof_section)
		printf("%d\t%lf\t%lf\t%lf\t%lf\n",cnt,min_loss,p.post_prob[index_tmp][0][1],p.post_prob[index_tmp][1][1], min_post_prob2);
	return min_loss+5*(min_post_prob2<1?min_post_prob2:1)-cnt*0.0015;
}


void calc_grad_ccon() {
	for (int i = 0; i < n_NN; i++) {
		double approx_loss_class[5];
		for (int j = 0; j < n_sample_of_action; j++) {
			double expsum = 0;
			for (int k = 0; k < n_class; k++) grad_ccon[i][j][k] = 0;
			 {
				for (int k = 0; k < n_class; k++) {
					classassign newclass = normalclass;
					newclass.assign[i][j] = k;
					approx_loss_class[k] = approxloss(calculate_posterior(normaltable), newclass);
					expsum += exp(ccon_w[i][j][k]);
				}
				for (int k = 0; k < n_class; k++) {
					for (int l = 0; l < n_class; l++) {
						if (l == k) grad_ccon[i][j][k] += approx_loss_class[l] * (exp(ccon_w[i][j][l]) / expsum);
						grad_ccon[i][j][k] += -approx_loss_class[l] * (exp(ccon_w[i][j][l]) * exp(ccon_w[i][j][k]) / expsum / expsum);
					}
					if (grad_ccon[i][j][k]>10 || grad_ccon[i][j][k]<-10) printf("%lf\t%lf\t%lf\t%lf\n",exp(ccon_w[i][j][0]),exp(ccon_w[i][j][1]),exp(ccon_w[i][j][2]),expsum);
					if (grad_ccon[i][j][k]>0.0005)	grad_ccon[i][j][k]=0.0005;
					if (grad_ccon[i][j][k]<-0.0005)	grad_ccon[i][j][k]=-0.0005;

				}
			}
		}
	}
}


/**
 * Compute grad_test_case_ccon, the gradient of the approximate loss with
 * respect to the per-test-case class values test_case_ccon_w.
 *
 * Step 1: for every (label, class) pair, evaluate approxloss() on copies of
 * normaltable whose two counters for that pair are shifted by +1 / -1 in the
 * six combinations named below.
 * Step 2: for every test case, combine those losses with main_approxloss
 * through softmax terms evaluated at value - xi and value + xi, choosing the
 * formula by the normt_eff flags of the test case.
 */
void calc_grad_test_case_ccon() {
	for (int pm = 0; pm < n_NN; pm++) {
		// Perturbed losses, indexed [label][class].
		// NOTE(review): the minus* entries are only assigned when the
		// corresponding counter is non-zero; otherwise they stay uninitialised.
		double plusoneone_approxloss[5][10], plusone_minusxi_approxloss[5][10], plusone_plusxi_approxloss[5][10], minusoneone_approxloss[5][10], minusone_minusxi_approxloss[5][10], minusone_plusxi_approxloss[5][10];
		for (int i = 0; i < n_rank; i++) {
			for (int j = 0; j < n_class; j++) {
				table newtable = normaltable;
				// +(1,1)
				newtable.results[pm][i][j][0]++;
				newtable.results[pm][i][j][1]++;
				plusoneone_approxloss[i][j] = approxloss(calculate_posterior(newtable), normalclass);
				// +(1,0)
				newtable.results[pm][i][j][1]--;
				plusone_minusxi_approxloss[i][j] = approxloss(calculate_posterior(newtable), normalclass);
				// +(0,1)
				newtable.results[pm][i][j][0]--;
				newtable.results[pm][i][j][1]++;
				plusone_plusxi_approxloss[i][j] = approxloss(calculate_posterior(newtable), normalclass);
				
				// Back to the unmodified counters before the "minus" variants.
				newtable.results[pm][i][j][1]--;
				
				// -(0,1)
				if (newtable.results[pm][i][j][1]) {
					newtable.results[pm][i][j][1]--;
					minusone_plusxi_approxloss[i][j] = approxloss(calculate_posterior(newtable), normalclass);
					newtable.results[pm][i][j][1]++;
				}
				// -(1,0)
				if (newtable.results[pm][i][j][0]) {
					newtable.results[pm][i][j][0]--;
					minusone_minusxi_approxloss[i][j] = approxloss(calculate_posterior(newtable), normalclass);
					newtable.results[pm][i][j][0]++;
				}
				// -(1,1)
				if (newtable.results[pm][i][j][0] && newtable.results[pm][i][j][1]) {
					newtable.results[pm][i][j][0]--;
					newtable.results[pm][i][j][1]--;
					minusoneone_approxloss[i][j] = approxloss(calculate_posterior(newtable), normalclass);
				}
			}
						
		}
		int n_test_case_pm = const_one_label_intest_size_per_sample*n_rank;
		for (int j = 0; j < n_test_case_pm; j++) {
			// Softmax normaliser of the unshifted values of this test case.
			double expsum = 0;
			for (int i = 0; i < n_class; i++) expsum += exp(test_case_ccon_w[pm][j][i]);
			// i: class whose gradient entry is accumulated; k: class whose
			// counters are considered.
			for (int i = 0; i < n_class; i++) {
				grad_test_case_ccon[pm][j][i] = 0;
				for (int k = 0; k < n_class; k++) {
					//calculate expsum for +-xi
					double mxi_expsum = expsum + exp(test_case_ccon_w[pm][j][k] - xi) - exp(test_case_ccon_w[pm][j][k]);
					double pxi_expsum = expsum + exp(test_case_ccon_w[pm][j][k] + xi) - exp(test_case_ccon_w[pm][j][k]);
					if (normt_eff[pm][j][k][0] && normt_eff[pm][j][k][1]) {
						//(1,1)
						if (k == i) {
							// (1 1) main*softmax(-xi)	(0 1) minusone_minusxi*(softmax(+xi)-softmax(-xi))	(0 0) minusoneone*(1-softmax(+xi))
							grad_test_case_ccon[pm][j][i] += (main_approxloss - minusone_minusxi_approxloss[ans[pm][j]][k]) * exp(test_case_ccon_w[pm][j][k] - xi) / mxi_expsum;
							grad_test_case_ccon[pm][j][i] += (minusone_minusxi_approxloss[ans[pm][j]][k] - minusoneone_approxloss[ans[pm][j]][k]) * exp(test_case_ccon_w[pm][j][k] + xi) / pxi_expsum;
							grad_test_case_ccon[pm][j][i] += -(main_approxloss - minusone_minusxi_approxloss[ans[pm][j]][k]) * exp(test_case_ccon_w[pm][j][k] - xi) * exp(test_case_ccon_w[pm][j][i] - xi) / mxi_expsum / mxi_expsum;
							grad_test_case_ccon[pm][j][i] += -(minusone_minusxi_approxloss[ans[pm][j]][k] - minusoneone_approxloss[ans[pm][j]][k]) * exp(test_case_ccon_w[pm][j][k] + xi) * exp(test_case_ccon_w[pm][j][i] + xi) / pxi_expsum / pxi_expsum;
						}
						else {
							grad_test_case_ccon[pm][j][i] += -(main_approxloss - minusone_minusxi_approxloss[ans[pm][j]][k]) * exp(test_case_ccon_w[pm][j][k] - xi) * exp(test_case_ccon_w[pm][j][i]) / mxi_expsum / mxi_expsum;
							grad_test_case_ccon[pm][j][i] += -(minusone_minusxi_approxloss[ans[pm][j]][k] - minusoneone_approxloss[ans[pm][j]][k]) * exp(test_case_ccon_w[pm][j][k] + xi) * exp(test_case_ccon_w[pm][j][i]) / pxi_expsum / pxi_expsum;
						}
					}
					else if (normt_eff[pm][j][k][1]) {
						//(0,1)
						if (k == i) {
							// (1 1) plusone_minusxi*softmax(-xi)	(0 1) main*(softmax(+xi)-softmax(-xi))	(0 0) minusone_plusxi*(1-softmax(+xi))
							grad_test_case_ccon[pm][j][i] += (plusone_minusxi_approxloss[ans[pm][j]][k] - main_approxloss) * exp(test_case_ccon_w[pm][j][k] - xi) / mxi_expsum;
							grad_test_case_ccon[pm][j][i] += (main_approxloss - minusone_plusxi_approxloss[ans[pm][j]][k]) * exp(test_case_ccon_w[pm][j][k] + xi) / pxi_expsum;
							grad_test_case_ccon[pm][j][i] += -(plusone_minusxi_approxloss[ans[pm][j]][k] - main_approxloss) * exp(test_case_ccon_w[pm][j][k] - xi) * exp(test_case_ccon_w[pm][j][i] - xi) / mxi_expsum / mxi_expsum;
							grad_test_case_ccon[pm][j][i] += -(main_approxloss - minusone_plusxi_approxloss[ans[pm][j]][k]) * exp(test_case_ccon_w[pm][j][k] + xi) * exp(test_case_ccon_w[pm][j][i] + xi) / pxi_expsum / pxi_expsum;
						}
						else {
							grad_test_case_ccon[pm][j][i] += -(plusone_minusxi_approxloss[ans[pm][j]][k] - main_approxloss) * exp(test_case_ccon_w[pm][j][k] - xi) * exp(test_case_ccon_w[pm][j][i]) / mxi_expsum / mxi_expsum;
							grad_test_case_ccon[pm][j][i] += -(main_approxloss - minusone_plusxi_approxloss[ans[pm][j]][k]) * exp(test_case_ccon_w[pm][j][k] + xi) * exp(test_case_ccon_w[pm][j][i]) / pxi_expsum / pxi_expsum;
						}
					}
					else {
						//(0,0)
						if (k == i) {
							// (1 1) plusoneone*softmax(-xi)	(0 1) plusone_plusxi*(softmax(+xi)-softmax(-xi))	(0 0) main*(1-softmax(+xi))
							grad_test_case_ccon[pm][j][i] += (plusoneone_approxloss[ans[pm][j]][k] - plusone_plusxi_approxloss[ans[pm][j]][k]) * exp(test_case_ccon_w[pm][j][k] - xi) / mxi_expsum;
							grad_test_case_ccon[pm][j][i] += (plusone_plusxi_approxloss[ans[pm][j]][k] - main_approxloss) * exp(test_case_ccon_w[pm][j][k] + xi) / pxi_expsum;
							grad_test_case_ccon[pm][j][i] += -(plusoneone_approxloss[ans[pm][j]][k] - plusone_plusxi_approxloss[ans[pm][j]][k]) * exp(test_case_ccon_w[pm][j][k] - xi) * exp(test_case_ccon_w[pm][j][i] - xi) / mxi_expsum / mxi_expsum;
							grad_test_case_ccon[pm][j][i] += -(plusone_plusxi_approxloss[ans[pm][j]][k] - main_approxloss) * exp(test_case_ccon_w[pm][j][k] + xi) * exp(test_case_ccon_w[pm][j][i] + xi) / pxi_expsum / pxi_expsum;
						}
						else {
							grad_test_case_ccon[pm][j][i] += -(plusoneone_approxloss[ans[pm][j]][k] - plusone_plusxi_approxloss[ans[pm][j]][k]) * exp(test_case_ccon_w[pm][j][k] - xi) * exp(test_case_ccon_w[pm][j][i]) / mxi_expsum / mxi_expsum;
							grad_test_case_ccon[pm][j][i] += -(plusone_plusxi_approxloss[ans[pm][j]][k] - main_approxloss) * exp(test_case_ccon_w[pm][j][k] + xi) * exp(test_case_ccon_w[pm][j][i]) / pxi_expsum / pxi_expsum;
						}
					}
				}
			}
		}
	}
}


// Reset the per-message globals ("section 2") before the next message is
// read, then advance xi by one step on the square-root scale, capped at
// xi_end.
void initialize_variables() {
	// scalar sentinels
	mode = -10000;
	n_class = -10000;

	// inputs
	memset(default_action_batch, 0, sizeof(default_action_batch));
	memset(mu_batch, 0, sizeof(mu_batch));
	memset(sigma_batch, 0, sizeof(sigma_batch));
	memset(ccon_w_batch, 0, sizeof(ccon_w_batch));
	memset(n_test_case, 0, sizeof(n_test_case));
	memset(ans_batch, 0, sizeof(ans_batch));
	memset(test_case_ccon_w_batch, 0, sizeof(test_case_ccon_w_batch));

	// outputs
	memset(act_batch, 0, sizeof(act_batch));
	memset(grad_ccon_batch, 0, sizeof(grad_ccon_batch));
	memset(grad_test_case_ccon_batch, 0, sizeof(grad_test_case_ccon_batch));

	// xi-planning
	const double next_root = sqrt(xi) + sqrt_xi_interval;
	xi = min(next_root*next_root, xi_end);
}
// Reset the per-batch working state ("section 3" globals, normaltable,
// normalclass and normt_eff) before one batch is processed.
void initialize_variables_for_batch() {
	// scalar sentinels
	default_action = -10000;
	act = -10000;
	main_approxloss = -10000;

	// distribution parameters
	mu[0] = mu[1] = 0;
	sigma[0] = sigma[1] = 0;

	// inputs copied from the batch arrays
	memset(ccon_w, 0, sizeof(ccon_w));
	memset(ans, 0, sizeof(ans));
	memset(ans_sum, 0, sizeof(ans_sum));
	memset(test_case_ccon_w, 0, sizeof(test_case_ccon_w));

	// outputs
	memset(grad_ccon, 0, sizeof(grad_ccon));
	memset(grad_test_case_ccon, 0, sizeof(grad_test_case_ccon));

	// derived tables
	normaltable.reset();
	normalclass.reset();
	memset(normt_eff, 0, sizeof(normt_eff));
}


/**
 * Entry point.
 *
 * Usage: <program> <instance-number>
 *
 * The instance number selects the log file names and, multiplied by fixed
 * constants, the two IPC keys. After setting up both shared-memory segments
 * and both semaphores, the program loops forever:
 *   wait on sem_1 -> read and parse segment 1 -> process every batch ->
 *   write results to segment 2 -> post sem_2.
 * The loop has no exit; shutdown happens through the SIGINT handler.
 *
 * Checks added here: a missing argument and log files that cannot be opened
 * are reported before any IPC object is created, so nothing is left behind.
 */
int main(int argc, char** argv)
{
	if (argc < 2) {
		fprintf(stderr, "usage: %s <instance-number>\n", argc > 0 ? argv[0] : "program");
		return 1;
	}
	const int tmp = stoi(argv[1]);

	string grad_path = "/home/user/RL_SYS/grad";
	grad_path += to_string(pof_section);
	// Open one append-mode log file or stop with a message naming the path.
	auto open_log = [&](const char* stem) -> FILE* {
		const string path = grad_path + stem + to_string(tmp) + ".txt";
		FILE* fp = fopen(path.c_str(), "a+");
		if (fp == NULL) {
			perror(path.c_str());
			exit(1);
		}
		return fp;
	};
	FILE* POSTERIOR = open_log("/POSTERIOR_");
	FILE* TRAINLOSS = open_log("/PURE_LOSS_");
	FILE* VALLOSS = open_log("/VAL_LOSS_");

	signal(SIGINT, signal_callback_handler);
	key_1 = tmp*(61011+pof_section);
	key_2 = tmp*(61022+pof_section);
	pof_shared_memory_1.setKey(key_1);
	pof_shared_memory_1.setMem(permission_1, 820000); // u-g-o(rw-w-w) => 602
	pof_shared_memory_2.setKey(key_2);
	pof_shared_memory_2.setMem(permission_2, 360000); // u-g-o(rw-r-r) => 604

	cout << "xi_start:" << xi_start << ", xi_end:" << xi_end << ", total_epoch:" << total_epoch << ", sqrt_xi_interval:" << sqrt_xi_interval << endl;
	if (checkpoint > 0) {
		xi = (sqrt_xi_interval*checkpoint*1000)*(sqrt_xi_interval*checkpoint*1000);
		cout << "xi initialization complete " << xi << endl;
		if (xi >= xi_end) xi = xi_end;
		cout << "xi clipping complete " << xi << endl;
	}
	// Both semaphores start at 1 and are taken immediately, so each one is
	// at 0 until something posts it.
	sem_init(&sem_1, key_1);
	sem_init(&sem_2, key_2);
	sem_wait(&sem_1);
	sem_wait(&sem_2);
	cout << "STAND BY - S1-0 & S2-0 !!!" << endl;
	srand((unsigned int)time(NULL));
	while (true) {
		initialize_variables();
		//cout << "OUTPUT WAIT ---" << endl; 
		sem_wait(&sem_1);
		pof_shared_memory_1.readMem();
		pof_output_reading(); // S1-0 & S2-0
		//cout << "OUTPUT READING COMPLETE" << endl; 
		for (int iter = 0; iter < n_batch; iter++) {
			initialize_variables_for_batch();
			// prepare for processing of one batch
			default_action = default_action_batch[iter];
			mu[0]=mu_batch[iter][0];
			mu[1]=mu_batch[iter][1];
			sigma[0]=sigma_batch[iter][0];
			sigma[1]=sigma_batch[iter][1];
			// Inputs are halved when copied into the working arrays.
			for (int j = 0; j < n_sample_of_action; j++) {
				for (int k = 0; k < n_class; k++) ccon_w[0][j][k] = ccon_w_batch[0][iter][j][k]/2.0;
			}
			for (int j = 0; j < const_one_label_intest_size_per_sample*n_rank; j++) {
				ans[0][j] = ans_batch[0][iter][j];
				ans_sum[0][ans[0][j]]++; 
				for (int k = 0; k < n_class; k++) test_case_ccon_w[0][j][k] = test_case_ccon_w_batch[0][iter][j][k]/2.0;
			}
			// calculate one batch loss
			if (mode == -1) break;
			calc_normaltable();
			calc_normalclass();
			// The posterior is computed once and reused for the main loss.
			const posterior posterior_tmp = calculate_posterior(normaltable);
			main_approxloss = approxloss(posterior_tmp, normalclass, 1);
			if(key_1==61011+pof_section) printf("%lf\n", main_approxloss);
			// calculate one batch gradient
			if (mode == 0) {
				fprintf(TRAINLOSS, "%s\n", to_string(main_approxloss).c_str());
				calc_grad_ccon();
				calc_grad_test_case_ccon();
			}
			else fprintf(VALLOSS, "%s\n", to_string(main_approxloss).c_str());
			// copy one batch data
			act_batch[iter] = act;
			for (int j = 0; j < n_sample_of_action; j++) {
				for (int k = 0; k < n_class; k++) {
					grad_ccon_batch[0][iter][j][k] = grad_ccon[0][j][k];
				}
			}
			for (int j = 0; j < const_one_label_intest_size_per_sample*n_rank; j++) {
				for (int k = 0; k < n_class; k++) {
					grad_test_case_ccon_batch[0][iter][j][k] = grad_test_case_ccon[0][j][k];
				}
			}
		}
		
		pof_update_writing();
		//cout << "UPDATE WRITING COMPLETE" << endl;
		sem_post(&sem_2); // S1-0 & S2-1
	}
	// Not reached while the loop above has no exit; kept so that cleanup is
	// complete if an exit condition is ever added.
	sem_destroy(&sem_1);
	sem_destroy(&sem_2);
	fclose(POSTERIOR);
	fclose(TRAINLOSS);
	fclose(VALLOSS);
	cout << "DESTROY S1 & S2" << endl;
	pof_shared_memory_1.close();
	pof_shared_memory_2.close();
	cout << "DESTROY SHM1 & SHM2" << endl;
	return 0;
}